Pytorch中的强化学习

torch.distributions.Categorical()

功能:根据概率分布来产生sample,产生的sample是输入tensor的index
如:
  >>> m = Categorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
  >>> m.sample()   # equal probability of 0, 1, 2, 3
  tensor(3)
 
posted @ 2018-09-23 14:54  sbj123456789  阅读(6779)  评论(0编辑  收藏  举报